"""
functions for merging patent to patent similarity scores between min_year and max_year
2024/06/11
"""

import pandas as pd
import os
import numpy as np
import glob
import tqdm

# Relative paths used by main() ("patentsberta/...", "cleandata/...") resolve
# against this PatentsView data root. NOTE: this runs at import time.
_PATENTSVIEW_ROOT = r"/Volumes/Zihao_SSD2/PatentsView/"
os.chdir(_PATENTSVIEW_ROOT)


def load_patent_dict(path: str) -> pd.DataFrame:
    """
    Load the patent dictionary as [patent_idx, patent_id, patent_year].

    Reads only the ``idx``, ``patent_id`` and ``patent_year`` columns of the CSV,
    renames ``idx`` to ``patent_idx`` and, if any (patent_idx, patent_id) pair is
    duplicated, prints a warning and keeps the first occurrence.

    Args:
        path: path (or file-like object) of the patent dict CSV.
    Returns:
        DataFrame with columns ["patent_idx", "patent_id", "patent_year"], in that order.
        patent_year is int16 (read_csv raises if that column has missing values).
    """
    df_patent = pd.read_csv(
        path,
        usecols=["idx", "patent_id", "patent_year"],
        # dtype keys must match the CSV header: the previous key "year" matched no
        # column, so the intended int16 was never applied to patent_year.
        dtype={"idx": "int64", "patent_id": int, "patent_year": "int16"},
    )
    # read_csv returns `usecols` columns in *file* order, not list order, so rename
    # by name and select explicitly instead of assigning positional labels.
    df_patent = df_patent.rename(columns={"idx": "patent_idx"})[
        ["patent_idx", "patent_id", "patent_year"]
    ]

    if df_patent.duplicated(["patent_idx", "patent_id"]).any():
        print("warning. duplicated [patent_idx, patent_id] found in patent_dict.csv!")
        df_patent = df_patent.drop_duplicates(["patent_idx", "patent_id"])
    return df_patent


def load_sim_files(file_path, year):
    """
    List the cosine-similarity CSV files for one patent year.

    Matches ``<file_path>/sim_<year>_topk/*.csv`` and prints how many were found.

    Args:
        file_path: root directory of the similarity files, e.g. "patentsberta/patent_sim_kpss".
        year: patent year whose files to list.
    Returns:
        Sorted list of matching file paths (empty if none match).
    """
    pattern = os.path.join(file_path, f"sim_{year}_topk", "*.csv")
    # glob's order is filesystem-dependent; sort so every run reads files in the same order.
    data_paths = sorted(glob.glob(pattern))
    print(f"num files = {len(data_paths)} for year {year}")
    return data_paths


def _read_sim_files(paths) -> pd.DataFrame:
    """Read and concatenate the similarity CSVs in `paths` (patent_idx, cited_patent_idx, sim_score)."""
    return pd.concat(
        [
            pd.read_csv(
                path,
                dtype={
                    "patent_idx": int,
                    "cited_patent_idx": int,
                    "sim_score": "float32",
                },
            )
            for path in paths
        ],
        ignore_index=True,
    )


def merge_years(patent_dict_path: str, sim_file_path: str, min_year: int, max_year: int, top_k: int) -> pd.DataFrame:
    """
    Merge all years of processed cosine similarity DataFrames.

    For each year in [min_year, max_year], reads that year's similarity files,
    attaches patent id/year for both the focal and the cited patent, and keeps the
    `top_k` rows with the highest sim_score per patent_idx (ties broken by larger
    cited_patent_idx).

    Args:
        patent_dict_path: CSV mapping idx -> patent_id, patent_year (see load_patent_dict).
        sim_file_path: root directory holding sim_<year>_topk/*.csv files.
        min_year: min patent year to merge (inclusive).
        max_year: max patent year to merge (inclusive).
        top_k: keep top k similar patents per patent_idx to save space.
    Returns:
        df ['patent_id', 'patent_year', 'cited_patent_id', 'cited_patent_year', 'sim_score', 'patent_idx', 'cited_patent_idx'],
        sorted by patent_idx asc, sim_score desc, cited_patent_idx desc.
    Raises:
        ValueError: if min_year > max_year, if a year has no similarity files, or if
            merging with the patent dict changes a year's row count (an idx is missing
            from, or duplicated in, the patent dict).
    """
    if min_year > max_year:
        raise ValueError(f"min_year ({min_year}) must be <= max_year ({max_year})")

    df_patent = load_patent_dict(patent_dict_path)
    # Lookup for the cited side with explicit column names, so the second merge
    # does not depend on pandas' automatic _x/_y suffixes.
    df_cited_patent = df_patent.rename(
        columns={
            "patent_idx": "cited_patent_idx",
            "patent_id": "cited_patent_id",
            "patent_year": "cited_patent_year",
        }
    )

    sort_cols = ["patent_idx", "sim_score", "cited_patent_idx"]
    sort_ascending = [True, False, False]

    yr_list = []
    for yr in tqdm.tqdm(range(min_year, max_year + 1)):
        paths = load_sim_files(sim_file_path, yr)
        if not paths:
            raise ValueError(f"no similarity files found for year {yr} under {sim_file_path}")
        print(f"start process and merge {len(paths)} files for year {yr}")

        df_raw = _read_sim_files(paths)
        n_raw = len(df_raw)
        df_file = df_raw.merge(df_patent, on="patent_idx").merge(
            df_cited_patent, on="cited_patent_idx"
        )
        del df_raw
        print(f"rows after merge for year {yr}: {len(df_file)}")
        # Inner joins drop rows whose idx is not in the dict and duplicate rows whose
        # idx appears twice; either would corrupt the output, so fail loudly
        # (an explicit raise, unlike assert, is not stripped under `python -O`).
        if len(df_file) != n_raw:
            raise ValueError(
                f"year {yr}: merge with patent dict changed row count from {n_raw} to {len(df_file)}; "
                "some patent_idx/cited_patent_idx are missing from or duplicated in the patent dict"
            )

        df_file = df_file.sort_values(sort_cols, ascending=sort_ascending)
        yr_list.append(df_file.groupby("patent_idx").head(top_k).reset_index(drop=True))
        del df_file

    df = pd.concat(yr_list, ignore_index=True)
    del yr_list

    df = df.sort_values(sort_cols, ascending=sort_ascending)
    return df[['patent_id', 'patent_year', 'cited_patent_id', 'cited_patent_year', 'sim_score', 'patent_idx', 'cited_patent_idx']]


def main():
    """Merge similarity scores for START_YEAR..END_YEAR (top TOP_K per patent) and write them to cleandata/."""
    PATENT_DICT_PATH = "patentsberta/patent_dict_kpss.csv"
    SIM_FILE_PATH = "patentsberta/patent_sim_kpss"
    START_YEAR = 1981
    END_YEAR = 2015
    TOP_K = 5
    out_dir = "cleandata"
    out_path = os.path.join(out_dir, f"sim_score_{START_YEAR}_{END_YEAR}_top{TOP_K}_kpss.csv")
    # Create the output directory before the long merge, so a missing folder
    # does not make to_csv fail only after all the work is done.
    os.makedirs(out_dir, exist_ok=True)

    df_sim = merge_years(PATENT_DICT_PATH, SIM_FILE_PATH, START_YEAR, END_YEAR, TOP_K)
    print("nrows of df_sim: ", len(df_sim))
    df_sim.to_csv(out_path, index=False)


if __name__ == "__main__":
    main()
